作者:极神bd韵 | 来源:互联网 | 2023-08-19 12:50
篇首语:本文由编程笔记#小编为大家整理,主要介绍了目标检测YOLOv5:添加漏检率和虚检率输出相关的知识,希望对你有一定的参考价值。
前言
在目标检测领域,衡量一个模型的优劣的指标往往是mAP,然而实际工程中,有时候更倾向于看漏检率和虚检率。YOLOv5的原始代码并没有这两个指标的输出,因此我想利用原始代码的混淆矩阵,输出这两个指标数值。
指标解释
漏检即原本有目标存在却没有检测出来,换句话说就是原本是目标却检测成了背景。
虚检(虚警)即原本没有目标却误认为有目标,换句话说就是原本是背景却检测成了目标。
首先来看YOLOv5原本输出的混淆矩阵,图中灰色覆盖的地方是原本输出的各类别,也就是输出的正例,最后一行和一列是背景类。
列是模型预测的结果,行是标签的真实结果。可以看到最后一行出现数值,表示出现了漏检;最后一列出现数值,则表示出现了虚检。
代码改进
现在来看YOLOv5输出的混淆矩阵代码部分,代码主要位于metrics.py
的ConfusionMatrix
类中。
class ConfusionMatrix:
def __init__(self, nc, conf=0.25, iou_thres=0.45):
"""
params nc: 数据集类别个数
params conf: 预测框置信度阈值
Params iou_thres: iou阈值
"""
self.matrix = np.zeros((nc + 1, nc + 1))
self.nc = nc
self.conf = conf
self.iou_thres = iou_thres
self.lou = 0
self.total = 0
self.xu = 0
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
detections = detections[detections[:, 4] > self.conf]
gt_classes = labels[:, 0].int()
detection_classes = detections[:, 5].int()
iou = general.box_iou(labels[:, 1:], detections[:, :4])
x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
else:
matches = np.zeros((0, 3))
n = matches.shape[0] > 0
m0, m1, _ = matches.transpose().astype(np.int16)
for i, gc in enumerate(gt_classes):
j = m0 == i
if n and sum(j) == 1:
self.matrix[gc, detection_classes[m1[j]]] += 1
else:
self.matrix[self.nc, gc] += 1
if n:
for i, dc in enumerate(detection_classes):
if not any(m1 == i):
self.matrix[dc, self.nc] += 1
self.lou = sum(self.matrix[-1, :])
self.total = sum(sum(self.matrix))
self.xu = sum(self.matrix[:, -1])
def matrix(self):
return self.matrix
def plot(self, save_dir='', names=()):
try:
import seaborn as sn
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6)
array[array < 0.005] &#61; np.nan
fig &#61; plt.figure(figsize&#61;(12, 9), tight_layout&#61;True)
sn.set(font_scale&#61;1.0 if self.nc < 50 else 0.8)
labels &#61; (0 < len(names) < 99) and len(names) &#61;&#61; self.nc
sn.heatmap(array, annot&#61;self.nc < 30, annot_kws&#61;"size": 8, cmap&#61;&#39;Blues&#39;, fmt&#61;&#39;.2f&#39;, square&#61;True,
xticklabels&#61;names &#43; [&#39;background FP&#39;] if labels else "auto",
yticklabels&#61;names &#43; [&#39;background FN&#39;] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel(&#39;True&#39;)
fig.axes[0].set_ylabel(&#39;Predicted&#39;)
fig.savefig(Path(save_dir) / &#39;confusion_matrix.png&#39;, dpi&#61;250)
except Exception as e:
pass
def print(self):
for i in range(self.nc &#43; 1):
print(&#39; &#39;.join(map(str, self.matrix[i])))
阅读代码可以发现&#xff0c;混淆矩阵再绘制时对每一列单独进行了归一化&#xff0c;那么再绘制之前&#xff0c;混淆矩阵存储了每一个预测结果和真实结果的数目。
于是我添加了三个属性self.lou
、self.total &#61; 0
、self.xu &#61; 0
&#xff0c;分别统计漏检目标数目&#xff0c;总目标数目和虚检目标数目。
漏检目标数目只需要将混淆矩阵最后一行相加&#xff0c;虚检目标数目只需要将混淆矩阵最后一列相加&#xff0c;总目标数目则将混淆矩阵所有数量相加。
然后在test.py
中进行添加&#xff1a;
t &#61; tuple(x / seen * 1E3 for x in (t0, t1, t0 &#43; t1)) &#43; (imgsz, imgsz, batch_size)
if not training:
print(&#39;Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g&#39; % t)
print("漏检样本数为&#xff1a;")
print(int(confusion_matrix.lou))
print("漏检率为&#xff1a;")
print(confusion_matrix.lou / confusion_matrix.total)
print("虚检样本数为&#xff1a;")
print(int(confusion_matrix.xu))
print("虚检率为&#xff1a;")
print(confusion_matrix.xu / confusion_matrix.total)
if plots:
confusion_matrix.plot(save_dir&#61;save_dir, names&#61;list(names.values()))
if wandb_logger and wandb_logger.wandb:
val_batches &#61; [wandb_logger.wandb.Image(str(f), caption&#61;f.name) for f in sorted(save_dir.glob(&#39;test*.jpg&#39;))]
wandb_logger.log("Validation": val_batches)
if wandb_images:
wandb_logger.log("Bounding Box Debugger/Images": wandb_images)
输出效果&#xff1a;